{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nH85BOCo7YYk"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "9tQNAByc7U9g"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F7r2q0wS7bxf"
      },
      "source": [
        "# Play with AI - Guess the word\n",
        "\n",
        "This cookbook shows how to use the instruction-tuned version of Gemma as a chatbot to play a \"Guess the word\" game.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Guess_the_word.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZHrL4tqs7mYK"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
        "\n",
        "\n",
        "### Gemma setup on Kaggle\n",
        "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
        "\n",
        "* Get access to Gemma on kaggle.com.\n",
        "* Select a Colab runtime with sufficient resources to run the Gemma 2B model.\n",
        "* Generate and configure a Kaggle username and API key.\n",
        "\n",
        "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pQEE8RoO75F-"
      },
      "source": [
        "### Set environment variables\n",
        "\n",
        "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "XsY2Ut7a76Wa"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # Or \"tensorflow\" or \"torch\".\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ea_56Zpa78Gu"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "Install Keras and KerasNLP."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AxPjbcnC79ck"
      },
      "outputs": [],
      "source": [
        "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
        "!pip install -q -U keras-nlp\n",
        "!pip install -q -U keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a_QCPQLf8OU0"
      },
      "source": [
        "### Create a chat helper to manage the conversation state"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2BmB5Zua8Vs0"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "import keras\n",
        "import keras_nlp\n",
        "\n",
        "# Run at half precision to fit in memory\n",
        "keras.config.set_floatx(\"bfloat16\")\n",
        "\n",
        "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_instruct_2b_en\")\n",
        "gemma_lm.compile(sampler=\"top_k\")\n",
        "\n",
        "\n",
        "class ChatState():\n",
        "  \"\"\"\n",
        "  Manages the conversation history for a turn-based chatbot.\n",
        "  Follows the turn-based conversation guidelines for the Gemma family of models,\n",
        "  documented at https://ai.google.dev/gemma/docs/formatting.\n",
        "  \"\"\"\n",
        "\n",
        "  __START_TURN_USER__ = \"<start_of_turn>user\\n\"\n",
        "  __START_TURN_MODEL__ = \"<start_of_turn>model\\n\"\n",
        "  __END_TURN__ = \"<end_of_turn>\\n\"\n",
        "\n",
        "  def __init__(self, model, system=\"\"):\n",
        "    \"\"\"\n",
        "    Initializes the chat state.\n",
        "\n",
        "    Args:\n",
        "        model: The language model to use for generating responses.\n",
        "        system: (Optional) System instructions or bot description.\n",
        "    \"\"\"\n",
        "    self.model = model\n",
        "    self.system = system\n",
        "    self.history = []\n",
        "\n",
        "  def add_to_history_as_user(self, message):\n",
        "    \"\"\"\n",
        "    Adds a user message to the history with start/end turn markers.\n",
        "    \"\"\"\n",
        "    self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)\n",
        "\n",
        "  def add_to_history_as_model(self, message):\n",
        "    \"\"\"\n",
        "    Adds a model response to the history with a start-of-turn marker.\n",
        "    The end-of-turn marker is expected to be part of the generated text.\n",
        "    \"\"\"\n",
        "    self.history.append(self.__START_TURN_MODEL__ + message)\n",
        "\n",
        "  def get_history(self):\n",
        "    \"\"\"\n",
        "    Returns the entire chat history as a single string.\n",
        "    \"\"\"\n",
        "    return \"\".join(self.history)\n",
        "\n",
        "  def get_full_prompt(self):\n",
        "    \"\"\"\n",
        "    Builds the prompt for the language model, including history and system description.\n",
        "    \"\"\"\n",
        "    prompt = self.get_history() + self.__START_TURN_MODEL__\n",
        "    if self.system:\n",
        "      prompt = self.system + \"\\n\" + prompt\n",
        "    return prompt\n",
        "\n",
        "  def send_message(self, message):\n",
        "    \"\"\"\n",
        "    Handles sending a user message and getting a model response.\n",
        "\n",
        "    Args:\n",
        "        message: The user's message.\n",
        "\n",
        "    Returns:\n",
        "        The model's response.\n",
        "    \"\"\"\n",
        "    self.add_to_history_as_user(message)\n",
        "    prompt = self.get_full_prompt()\n",
        "    response = self.model.generate(prompt, max_length=4096)\n",
        "    result = response.replace(prompt, \"\")  # Extract only the new response\n",
        "    self.add_to_history_as_model(result)\n",
        "    return result\n",
        "\n",
        "  def show_history(self):\n",
        "    for h in self.history:\n",
        "      print(h)\n",
        "\n",
        "\n",
        "chat = ChatState(gemma_lm)"
      ]
    },
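    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the turn format that `ChatState` assembles, the sketch below rebuilds a prompt by hand without loading Gemma. The `build_prompt` helper and its `(role, text)` history format are hypothetical names introduced only for this example; they are not part of this cookbook's class or of KerasNLP. Like `ChatState`, the sketch assumes that the model's own text already ends with the end-of-turn marker."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical stand-alone sketch: no model is loaded here.\n",
        "START_TURN_USER = \"<start_of_turn>user\\n\"\n",
        "START_TURN_MODEL = \"<start_of_turn>model\\n\"\n",
        "END_TURN = \"<end_of_turn>\\n\"\n",
        "\n",
        "\n",
        "def build_prompt(history, system=\"\"):\n",
        "    \"\"\"Joins (role, text) pairs and leaves a model turn open at the end.\"\"\"\n",
        "    parts = []\n",
        "    for role, text in history:\n",
        "        if role == \"user\":\n",
        "            parts.append(START_TURN_USER + text + END_TURN)\n",
        "        else:\n",
        "            # Model text is stored as generated, mirroring ChatState.\n",
        "            parts.append(START_TURN_MODEL + text)\n",
        "    prompt = \"\".join(parts) + START_TURN_MODEL\n",
        "    return system + \"\\n\" + prompt if system else prompt\n",
        "\n",
        "\n",
        "demo_prompt = build_prompt([(\"user\", \"Generate a random single word from animal.\")])\n",
        "print(demo_prompt)"
      ]
    },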
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_1jyCoRd8EwX"
      },
      "source": [
        "## Play the game"
      ]
    },
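    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The game cell in this section does two bits of text processing: it reduces the model's first reply to a single clean word, and it masks that word in every hint before printing. A minimal sketch of both steps on hard-coded strings follows. The sample values `raw_reply` and `description` are made up for illustration and are not real model output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "# Hypothetical sample values, not real model output.\n",
        "raw_reply = \"**Otter**\\n\"\n",
        "description = \"An otter floats on its back. Otters use tools.\"\n",
        "\n",
        "# Take the first whitespace-separated token, then drop every character\n",
        "# that is not a letter, digit or underscore.\n",
        "demo_answer = re.sub(r\"\\W+\", \"\", raw_reply.split()[0])\n",
        "\n",
        "# Replace every case-insensitive occurrence of the answer in the hint.\n",
        "mask = re.compile(re.escape(demo_answer), re.IGNORECASE)\n",
        "masked = mask.sub(\"XXXX\", description)\n",
        "\n",
        "print(demo_answer)\n",
        "print(masked)"
      ]
    },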
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "zoWDt87V83rZ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Choose your theme: animal\n",
            "Guess what I'm thinking.\n",
            "Type \"quit\" if you want to quit.\n",
            "A playful, furry swimmer. Known for its playful antics and clever use of tools.  A member of the weasel-like family with a distinctive, waterproof coat. \n",
            "<end_of_turn>\n",
            "\n",
            "> platypus\n",
            "A creature that spends its days by a river or lake, often seen splashing and diving with its distinctive, thick fur.  It's known for its playful demeanor and ability to hold a surprisingly large amount of water in its paws. \n",
            "<end_of_turn>\n",
            "\n",
            "> beaver\n",
            "A small, aquatic mammal with a playful, curious nature, often spotted near water, known for its distinctive, waterproof fur and love for swimming. \n",
            "<end_of_turn>\n",
            "\n",
            "> otter\n",
            "Correct!\n"
          ]
        }
      ],
      "source": [
        "theme = input(\"Choose your theme: \")\n",
        "setup_message = f\"Generate a random single word from {theme}.\"\n",
        "\n",
        "chat.history.clear()\n",
        "answer = chat.send_message(setup_message).split()[0]\n",
        "answer = re.sub(r\"\\W+\", \"\", answer)  # strips everything except letters, digits and '_'\n",
        "chat.history.clear()\n",
        "cmd_exit = \"quit\"\n",
        "question = f'Describe the word \"{answer}\" without saying it.'\n",
        "\n",
        "resp = \"\"\n",
        "while resp.lower() != answer.lower() and resp != cmd_exit:\n",
        "    text = chat.send_message(question)\n",
        "    if resp == \"\":\n",
        "        print(f'Guess what I\\'m thinking.\\nType \"{cmd_exit}\" if you want to quit.')\n",
        "    remove_answer = re.compile(re.escape(answer), re.IGNORECASE)\n",
        "    text = remove_answer.sub(\"XXXX\", text)\n",
        "    print(text)\n",
        "    resp = input(\"\\n> \")\n",
        "\n",
        "if resp == cmd_exit:\n",
        "    print(f\"The answer was {answer}.\\n\")\n",
        "else:\n",
        "    print(\"Correct!\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]Guess_the_word.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
